{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7edcc568",
   "metadata": {},
   "source": [
    "# Assess predictions on DBPedia text classification data with a Hugging Face transformers model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad3b6312",
   "metadata": {},
   "source": [
    "This notebook demonstrates how to use the `responsibleai` API to assess a Hugging Face transformers text classification model trained on the DBPedia dataset. It walks through the API calls needed to create a widget with model analysis insights, then guides a visual analysis of the model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b55006f",
   "metadata": {},
   "source": [
    "* [Launch Responsible AI Toolbox](#Launch-Responsible-AI-Toolbox)\n",
    "    * [Load Model and Data](#Load-Model-and-Data)\n",
    "    * [Create Model and Data Insights](#Create-Model-and-Data-Insights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "739e0c4d",
   "metadata": {},
   "source": [
    "## Launch Responsible AI Toolbox"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05a7681f",
   "metadata": {},
   "source": [
    "The following section examines the code necessary to create datasets and a model. It then generates insights using the `responsibleai` API that can be visually analyzed."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79128ba0",
   "metadata": {},
   "source": [
    "### Load Model and Data\n",
    "*The following section can be skipped. It loads a dataset and downloads a pre-trained model for illustrative purposes.*"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47a217da",
   "metadata": {},
   "source": [
    "First, we import all necessary dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75ef9e91",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import pandas as pd\n",
    "import zipfile\n",
    "from transformers import (AutoModelForSequenceClassification, AutoTokenizer,\n",
    "                          pipeline)\n",
    "\n",
    "from raiutils.common.retries import retry_function\n",
    "\n",
    "try:\n",
    "    from urllib import urlretrieve\n",
    "except ImportError:\n",
    "    from urllib.request import urlretrieve"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0804d4d",
   "metadata": {},
   "source": [
    "Next, we load the DBPedia dataset from Hugging Face datasets. We use only 6 examples, plus 8 known error instances added in a later cell, because computing explanations can take some time, especially on CPU. You can increase `NUM_TEST_SAMPLES` to 100 or more to get a more interesting dashboard."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ae1bf09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bump up the number of examples to 100 or greater to view more\n",
    "# information, but it may take longer to compute\n",
    "NUM_TEST_SAMPLES = 6\n",
    "\n",
    "def load_dataset(split):\n",
    "    # Keep only the text and the l1 label column\n",
    "    dataset = datasets.load_dataset(\"DeveloperOats/DBPedia_Classes\", split=split)\n",
    "    return pd.DataFrame({\"text\": dataset[\"text\"], \"l1\": dataset[\"l1\"]})\n",
    "\n",
    "pd_valid_data = load_dataset(\"test\")\n",
    "\n",
    "def rename_label_column(dataset):\n",
    "    # Rename the l1 column to label without modifying the input in place\n",
    "    return dataset.rename(columns={\"l1\": \"label\"})\n",
    "\n",
    "pd_valid_data = rename_label_column(pd_valid_data)\n",
    "\n",
    "# Take NUM_TEST_SAMPLES rows starting at START_INDEX\n",
    "START_INDEX = 0\n",
    "test_data = pd_valid_data[START_INDEX:START_INDEX + NUM_TEST_SAMPLES]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f2631e7",
   "metadata": {},
   "source": [
    "Add some known error instances so that the dashboard has misclassifications to analyze."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "691e8fb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "error_indices = [101, 319, 391, 414, 644, 894, 1078, 1209]\n",
    "# DataFrame.append was removed in pandas 2.0, so use pd.concat instead\n",
    "test_data = pd.concat([test_data, pd_valid_data.iloc[error_indices]]).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a3e73da",
   "metadata": {},
   "source": [
    "Fetch a Hugging Face model pre-trained on the DBPedia dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36e20709",
   "metadata": {},
   "outputs": [],
   "source": [
    "DBPEDIA_MODEL_NAME = \"dbpedia_model\"\n",
    "NUM_LABELS = 9\n",
    "\n",
    "class FetchModel(object):\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "    def fetch(self):\n",
    "        zipfilename = DBPEDIA_MODEL_NAME + '.zip'\n",
    "        url = ('https://publictestdatasets.blob.core.windows.net/models/' +\n",
    "               DBPEDIA_MODEL_NAME + '.zip')\n",
    "        urlretrieve(url, zipfilename)\n",
    "        with zipfile.ZipFile(zipfilename, 'r') as unzip:\n",
    "            unzip.extractall(DBPEDIA_MODEL_NAME)\n",
    "\n",
    "def retrieve_dbpedia_model():\n",
    "    fetcher = FetchModel()\n",
    "    action_name = \"Model download\"\n",
    "    err_msg = \"Failed to download model\"\n",
    "    max_retries = 4\n",
    "    retry_delay = 60\n",
    "    retry_function(fetcher.fetch, action_name, err_msg,\n",
    "                   max_retries=max_retries,\n",
    "                   retry_delay=retry_delay)\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(\n",
    "        DBPEDIA_MODEL_NAME, num_labels=NUM_LABELS)\n",
    "    return model\n",
    "\n",
    "model = retrieve_dbpedia_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea0eef05",
   "metadata": {},
   "source": [
    "Load the tokenizer and build a prediction pipeline from the model and tokenizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c93f3d7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "\n",
    "# -1 runs on CPU; set to a GPU index such as 0 to run on CUDA\n",
    "device = -1\n",
    "if device >= 0:\n",
    "    model = model.cuda()\n",
    "\n",
    "# build a pipeline object to do predictions\n",
    "pred = pipeline(\n",
    "    \"text-classification\",\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    device=device,\n",
    "    return_all_scores=True\n",
    ")"
   ]
  },
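  {
   "cell_type": "markdown",
   "id": "a3d94c1e",
   "metadata": {},
   "source": [
    "Because the pipeline above is built with `return_all_scores=True`, each input text is expected to yield one label and score entry per class. The following is a minimal sketch, using hypothetical hand-written scores rather than real model output, of how such output could be reduced to a predicted class index. The `LABEL_<index>` names are an assumption based on the default transformers label naming; the downloaded model may use different names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7f20e58",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical pipeline output for two texts and three classes\n",
    "# (hand-written scores, not produced by the DBPedia model)\n",
    "toy_output = [\n",
    "    [{'label': 'LABEL_0', 'score': 0.1},\n",
    "     {'label': 'LABEL_1', 'score': 0.7},\n",
    "     {'label': 'LABEL_2', 'score': 0.2}],\n",
    "    [{'label': 'LABEL_0', 'score': 0.6},\n",
    "     {'label': 'LABEL_1', 'score': 0.3},\n",
    "     {'label': 'LABEL_2', 'score': 0.1}],\n",
    "]\n",
    "\n",
    "# For each text, pick the position of the highest-scoring entry\n",
    "predicted_ids = [\n",
    "    max(range(len(scores)), key=lambda i: scores[i]['score'])\n",
    "    for scores in toy_output\n",
    "]\n",
    "print(predicted_ids)"
   ]
  },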
  {
   "cell_type": "markdown",
   "id": "9d54a252",
   "metadata": {},
   "source": [
    "Wrap the pipeline, define the encoded classes, which must come from the original trained model, and display the number of errors on the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2178e47f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ml_wrappers import wrap_model\n",
    "wrapped_model = wrap_model(pred, test_data, 'text_classification')\n",
    "\n",
    "encoded_classes = ['Agent', 'Device', 'Event', 'Place', 'Species',\n",
    "                   'SportsSeason', 'TopicalConcept', 'UnitOfWork',\n",
    "                   'Work']\n",
    "# Map each class name to its integer index in encoded_classes\n",
    "labels = [encoded_classes.index(y) for y in test_data['label'].tolist()]\n",
    "\n",
    "# Compare predicted class indices against the true label indices\n",
    "predictions = wrapped_model.predict(test_data['text'].tolist())\n",
    "num_errors = sum(predictions != labels)\n",
    "print(\"number of errors on test dataset: \" + str(num_errors))"
   ]
  },
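  {
   "cell_type": "markdown",
   "id": "c1e68a04",
   "metadata": {},
   "source": [
    "The label encoding and error count above can be illustrated on toy values. This is a minimal sketch: `toy_labels` and `toy_predictions` are hypothetical and do not come from the model or the dataset. It also lists which rows are misclassified, not only how many."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d95b3f72",
   "metadata": {},
   "outputs": [],
   "source": [
    "toy_classes = ['Agent', 'Device', 'Event', 'Place', 'Species',\n",
    "               'SportsSeason', 'TopicalConcept', 'UnitOfWork',\n",
    "               'Work']\n",
    "\n",
    "# Hypothetical true labels and predicted class indices\n",
    "toy_labels = ['Place', 'Agent', 'Work']\n",
    "toy_predictions = [3, 0, 2]\n",
    "\n",
    "# Encode class names as positions in the class list\n",
    "toy_label_ids = [toy_classes.index(y) for y in toy_labels]\n",
    "\n",
    "# Rows where the predicted index differs from the true index\n",
    "error_rows = [i for i, (p, y) in enumerate(zip(toy_predictions, toy_label_ids))\n",
    "              if p != y]\n",
    "num_errors = len(error_rows)\n",
    "print(toy_label_ids, error_rows, num_errors)"
   ]
  },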
  {
   "cell_type": "markdown",
   "id": "9d5a9c6c",
   "metadata": {},
   "source": [
    "### Create Model and Data Insights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76a0763d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from responsibleai_text import RAITextInsights, ModelTask\n",
    "from raiwidgets import ResponsibleAIDashboard"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c4811f9",
   "metadata": {},
   "source": [
    "To use the Responsible AI Dashboard, initialize a `RAITextInsights` object, onto which different components can be loaded.\n",
    "\n",
    "`RAITextInsights` accepts the model, the test dataset, the name of the target column, the classes and the task type as its arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc8021d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "rai_insights = RAITextInsights(pred, test_data,\n",
    "                               \"label\",\n",
    "                               classes=encoded_classes,\n",
    "                               task_type=ModelTask.TEXT_CLASSIFICATION)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "449d15e7",
   "metadata": {},
   "source": [
    "Add the toolbox components used for model assessment: the explainer and error analysis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "526aca04",
   "metadata": {},
   "outputs": [],
   "source": [
    "rai_insights.explainer.add()\n",
    "rai_insights.error_analysis.add()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e02a8636",
   "metadata": {},
   "source": [
    "Once all the desired components have been loaded, compute insights on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f002d52",
   "metadata": {},
   "outputs": [],
   "source": [
    "rai_insights.compute()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16ed47d0",
   "metadata": {},
   "source": [
    "Finally, visualize and explore the model insights. Use the resulting widget or follow the link to view this in a new tab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b63dbfa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "ResponsibleAIDashboard(rai_insights)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
